Prediction API – Machine Learning from Google


Introduction

One of the exciting APIs among the 50+ APIs offered by Google is the Prediction API. It provides pattern-matching and machine learning capabilities such as recommendations or categorization. The notion is similar to the machine learning capabilities we can see in other solutions (e.g. in Apache Mahout): we can train the system with a set of training data, and then applications based on the Prediction API can recommend (“predict”) what products a user might like, categorize spam, etc.

In this post we go through an example of how to categorize SMS messages as either spam or valuable texts (“ham”).

Using Prediction API

In order to be able to use the Prediction API, the service needs to be enabled via the Google API console. To upload training data, the Prediction API also requires Google Cloud Storage.

The dataset used in this post is from the UCI Machine Learning Repository, which has 235 datasets publicly available; this post is based on the SMS Spam Collection dataset.

To upload the training data, first we need to create a bucket in Google Cloud Storage. From the Google API console we need to click on Google Cloud Storage and then on Google Cloud Storage Manager: this opens a webpage where we can create new buckets and upload or delete files.

[Screenshot: GoogleStorage2]

The UCI SMS Spam Collection file is not suitable for the Prediction API as is; it needs to be converted into the following format (the categories – ham/spam – need to be quoted, as does the SMS text):

“ham” “Go until jurong point, crazy.. Available only in bugis n great world la e buffet… Cine there got amore wat…”

[Screenshot: GoogleStorage4]
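If you would rather script this conversion than edit the file by hand, a minimal Python sketch could look like the one below. The file names are only examples; it assumes the raw UCI file is tab-separated with the label in the first column and writes comma-separated, double-quoted values, which is the training data format the Prediction API expects.

# Sketch: convert the UCI SMS Spam Collection ("label<TAB>text" lines)
# into double-quoted CSV rows usable as Prediction API training data.
# The input/output file names are hypothetical.
import csv

with open('SMSSpamCollection') as infile, \
     open('sms_spam_training.csv', 'wb') as outfile:
    writer = csv.writer(outfile, quoting=csv.QUOTE_ALL)
    for line in infile:
        label, text = line.rstrip('\n').split('\t', 1)
        writer.writerow([label, text])

The resulting file can then be uploaded to the bucket via the Google Cloud Storage Manager.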

The Google Prediction API offers a handful of commands that can be invoked via a REST interface. The simplest way of testing the Prediction API is to use the Prediction API explorer.

[Screenshot: GooglePrediction1]

Once the training data is available on Google Cloud Storage, we can start training the machine learning system behind the Prediction API. To begin training our model, we need to run prediction.trainedmodels.insert. All commands require authentication, which is based on the OAuth 2.0 standard.

[Screenshot: GooglePrediction2]

In the insert menu we need to specify the fields that we want to be included in the response. In the request body we need to define an id (this will be used as a reference to the model in the commands used later on), a storageDataLocation pointing to where we uploaded the training data (the Google Cloud Storage path) and the modelType (it can be regression or classification; for spam filtering it is classification):

[Screenshot: GooglePrediction-SpamInsert1]
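The same insert call can also be issued from Python with the Google API client library. The snippet below is only a sketch: it assumes http is an httplib2.Http object that has already been authorized with OAuth 2.0 credentials (as done in the predict.py sample later in this post), and the Cloud Storage path is a made-up example.

from apiclient.discovery import build

# Build the Prediction API client on top of an authorized HTTP object.
service = build('prediction', 'v1.4', http=http)
papi = service.trainedmodels()

# Request body: the model id, the training data location on Google
# Cloud Storage (bucket/object - example path) and the model type
# (classification for spam filtering).
training = {
    'id': 'bighadoop-00001',
    'storageDataLocation': 'bighadoop/sms_spam_training.csv',
    'modelType': 'classification',
}
papi.insert(body=training).execute()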

The training runs for a while; we can check its status using the prediction.trainedmodels.get command. The status field is going to be RUNNING and then changes to DONE once the training is finished.

[Screenshot: GooglePrediction-SpamGet1]

[Screenshot: GooglePrediction-SpamGet2]
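The status check can be scripted as well; here is a small sketch that reuses the papi handle from the snippet above and polls until training has finished (the status field is assumed to be called trainingStatus in the v1.4 response – worth double-checking against the get output shown above):

import time

# Poll prediction.trainedmodels.get until the model has been trained.
status = papi.get(id='bighadoop-00001').execute()
while status.get('trainingStatus') == 'RUNNING':
    time.sleep(10)
    status = papi.get(id='bighadoop-00001').execute()
print status.get('trainingStatus')   # DONE once training is finished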

Now we are ready to run our test against the machine learning system, and it is going to classify whether the given text is spam or ham. The Prediction API command for this action is prediction.trainedmodels.predict. In the id field we have to refer to the id that we defined in the prediction.trainedmodels.insert command (bighadoop-00001), and we also need to specify the request body – the input will be a csvInstance containing the text that we want to get categorized (e.g. “Free entry”).

[Screenshot: GooglePrediction-SpamPredict1]

The system then returns the category (spam) and the scores (0.822158 for spam, 0.177842 for ham):

[Screenshot: GooglePrediction-SpamPredict2]
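From Python, the same predict call – again reusing papi from the earlier sketch – would look roughly like this; for classification models the response carries the winning category (outputLabel) and a score per label (outputMulti):

# Classify a new SMS text with the trained model.
body = {'input': {'csvInstance': ['Free entry']}}
result = papi.predict(id='bighadoop-00001', body=body).execute()

print result['outputLabel']              # e.g. spam
for entry in result['outputMulti']:
    print entry['label'], entry['score']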

Google Prediction API libraries

Google also offers a featured sample application that includes all the code required to run it on Google App Engine. It is called Try-Prediction, and the code is written in both Python and Java. The application can be tested at http://try-prediction.appspot.com.

For instance, if we enter a quote from Niels Bohr for the Language Detection model – “Prediction is very difficult, especially if it’s about the future.” – it will return that the text is most likely English (54.4%).

[Screenshot: TryPrediction]

The key part of the Python code is in predict.py:

class PredictAPI(webapp.RequestHandler):
  '''This class handles Ajax prediction requests, i.e. not user initiated
     web sessions but remote procedure calls initiated from the Javascript
     client code running in the browser.
  '''

  def get(self):
    try:
      # Read server-side OAuth 2.0 credentials from datastore and
      # raise an exception if credentials not found.
      credentials = StorageByKeyName(CredentialsModel, USER_AGENT, 
                                    'credentials').locked_get()
      if not credentials or credentials.invalid:
        raise Exception('missing OAuth 2.0 credentials')

      # Authorize HTTP session with server credentials and obtain  
      # access to prediction API client library.
      http = credentials.authorize(httplib2.Http())
      service = build('prediction', 'v1.4', http=http)
      papi = service.trainedmodels()

      # Read and parse JSON model description data.
      models = parse_json_file(MODELS_FILE)

      # Get reference to user's selected model.
      model_name = self.request.get('model')
      model = models[model_name]

      # Build prediction data (csvInstance) dynamically based on form input.
      vals = []
      for field in model['fields']:
        label = field['label']
        val = str(self.request.get(label))
        vals.append(val)
      body = {'input' : {'csvInstance' : vals }}
      logging.info('model:' + model_name + ' body:' + str(body))

      # Make a prediction and return JSON results to Javascript client.
      ret = papi.predict(id=model['model_id'], body=body).execute()
      self.response.out.write(json.dumps(ret))

    except Exception, err:
      # Capture any API errors here and pass response from API back to
      # Javascript client embedded in a special error indication tag.
      err_str = str(err)
      if err_str[0:len(ERR_TAG)] != ERR_TAG:
        err_str = ERR_TAG + err_str + ERR_END
      self.response.out.write(err_str)

The Java version of the Prediction web application is as follows:

public class PredictServlet extends HttpServlet {

  @Override
  protected void doGet(HttpServletRequest request,
                       HttpServletResponse response) throws ServletException, 
                                                            IOException {
    Entity credentials = null;
    try {
      // Retrieve server credentials from app engine datastore.
      DatastoreService datastore = 
        DatastoreServiceFactory.getDatastoreService();
      Key credsKey = KeyFactory.createKey("Credentials", "Credentials");
      credentials = datastore.get(credsKey);
    } catch (EntityNotFoundException ex) {
      // If we can't obtain credentials, send the exception back to the
      // Javascript client and stop processing the request.
      response.setContentType("text/html");
      response.getWriter().println("exception: " + ex.getMessage());
      return;
    }

    // Extract tokens from retrieved credentials.
    AccessTokenResponse tokens = new AccessTokenResponse();
    tokens.accessToken = (String) credentials.getProperty("accessToken");
    tokens.expiresIn = (Long) credentials.getProperty("expiresIn");
    tokens.refreshToken = (String) credentials.getProperty("refreshToken");
    String clientId = (String) credentials.getProperty("clientId");
    String clientSecret = (String) credentials.getProperty("clientSecret");
    tokens.scope = IndexServlet.scope;

    // Set up the HTTP transport and JSON factory
    HttpTransport httpTransport = new NetHttpTransport();
    JsonFactory jsonFactory = new JacksonFactory();

    // Get user requested model, if specified.
    String model_name = request.getParameter("model");

    // Parse model descriptions from models.json file.
    Map<String, Object> models =
      IndexServlet.parseJsonFile(IndexServlet.modelsFile);

    // Setup reference to user specified model description.
    Map<String, Object> selectedModel =
      (Map<String, Object>) models.get(model_name);
    
    // Obtain model id (the name under which model was trained), 
    // and iterate over the model fields, building a list of Strings
    // to pass into the prediction request.
    String modelId = (String) selectedModel.get("model_id");
    List<String> params = new ArrayList<String>();
    List<Map<String, String>> fields =
      (List<Map<String, String>>) selectedModel.get("fields");
    for (Map<String, String> field : fields) {
      // This loop is populating the input csv values for the prediction call.
      String label = field.get("label");
      String value = request.getParameter(label);
      params.add(value);
    }

    // Set up OAuth 2.0 access of protected resources using the retrieved
    // refresh and access tokens, automatically refreshing the access token 
    // whenever it expires.
    GoogleAccessProtectedResource requestInitializer = 
      new GoogleAccessProtectedResource(tokens.accessToken, httpTransport, 
                                        jsonFactory, clientId, clientSecret, 
                                        tokens.refreshToken);

    // Now populate the prediction data, issue the API call and return the
    // JSON results to the Javascript AJAX client.
    Prediction prediction = new Prediction(httpTransport, requestInitializer, 
                                           jsonFactory);
    Input input = new Input();
    InputInput inputInput = new InputInput();
    inputInput.setCsvInstance(params);
    input.setInput(inputInput);
    Output output = 
      prediction.trainedmodels().predict(modelId, input).execute();
    response.getWriter().println(output.toPrettyString());
  }
}

Besides the Python and Java support, Google also offers .NET, Objective-C, Ruby, Go, JavaScript, PHP, etc. libraries for the Prediction API.

Google BigQuery


This time I write about Google BigQuery, a service that Google made publicly available in May 2012. It had been around for some time: the Google Research blog talked about it in 2010, Google announced a limited preview in November 2011, and eventually it went live this month.

The technology is based on Dremel, not MapReduce. The reason for having an alternative to MapReduce is described in the Dremel paper: “Dremel can execute many queries over such data that would ordinarily require a sequence of MapReduce … jobs, but at a fraction of the execution time. Dremel is not intended as a replacement for MR and is often used in conjunction with it to analyze outputs of MR pipelines or rapidly prototype larger computations“.

So what is BigQuery? As it is answered on Google BigQuery website: “Google BigQuery is a web service that lets you do interactive analysis of massive datasets—up to billions of rows.”

Getting Started with BigQuery

In order to be able to use BigQuery, first you need to sign up for it via the Google API console. Once that is done, you can start using the service. The easiest way to start is with the BigQuery Browser Tool.

BigQuery Browser Tool

When you first log in to the BigQuery Browser Tool, you are greeted with a welcome message.

There is already a public dataset available, so you can have a quick look around and get a feel for the BigQuery Browser Tool – for example, the github_timeline table, a snapshot from the GitHub Archive.

You can run a simple query against it using COMPOSE QUERY in the browser tool; the syntax is SQL-like:

SELECT repository_name, repository_owner, repository_description FROM publicdata:samples.github_timeline LIMIT 1000;

So far so good… Let us now create our own tables. The dataset I used is from the WorldBank Data Catalogue: GDP and population data for countries all over the world, available in CSV format (as well as Excel and PDF).

As a first step, we need to create the dataset – a dataset is basically a container for one or more tables in BigQuery. You need to click on the down-arrow icon next to the API project and select “Create new dataset”.

Then you need to create the table. Click on the down-arrow for the dataset (worldbank in our case) and select “Create new table”.

Then you need to define the table parameters such as name, schema and the source file to be uploaded. Note: Internet Explorer 8 does not seem to support CSV file upload (a “File upload is not currently supported in your browser.” message appears for the File upload link). You’d better go with Chrome, which supports CSV file upload.

When you upload the file, you need to specify the schema in the following format: county_code:string,ranking:integer,country_name:string,value:integer

There are advanced options available, too: you can use, e.g., tab-separated files instead of comma-separated ones, and you can define how many invalid rows are accepted, how many rows are skipped, etc.

During the upload, the data is validated against the specified schema; if the schema is violated, you will get error messages in the Job history (e.g. “Too many columns: expected 4 column(s) but got 5 column(s)”).
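For reference, the dataset creation and the CSV load can also be done programmatically against the BigQuery API v2. The sketch below assumes service is an authorized BigQuery client built exactly as in the Python application at the end of this post, and that the CSV has been copied to a Google Cloud Storage bucket first (the gs:// path is a made-up example) – unlike the browser tool, the API loads data from Cloud Storage rather than from a direct file upload.

PROJECT_ID = '190120083879'

# Create the worldbank dataset.
dataset = {'datasetReference': {'projectId': PROJECT_ID,
                                'datasetId': 'worldbank'}}
service.datasets().insert(projectId=PROJECT_ID, body=dataset).execute()

# Start a load job that fills the gdp table from a CSV in Cloud Storage.
load_job = {
    'configuration': {
        'load': {
            'sourceUris': ['gs://bighadoop/worldbank_gdp.csv'],  # example path
            'schema': {'fields': [
                {'name': 'county_code', 'type': 'STRING'},
                {'name': 'ranking', 'type': 'INTEGER'},
                {'name': 'country_name', 'type': 'STRING'},
                {'name': 'value', 'type': 'INTEGER'}]},
            'destinationTable': {'projectId': PROJECT_ID,
                                 'datasetId': 'worldbank',
                                 'tableId': 'gdp'},
            'skipLeadingRows': 1,   # e.g. skip a header row
            'maxBadRecords': 0      # reject the load on any invalid row
        }
    }
}
service.jobs().insert(projectId=PROJECT_ID, body=load_job).execute()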

Once the upload has successfully finished, you are ready to execute queries on the data. You can use COMPOSE QUERY for that, as we have already described for the github_timeline table. To display the top 10 countries with the highest GDP values, run the following query:

SELECT country_name, value FROM worldbank.gdp ORDER BY value DESC LIMIT 10

BigQuery Command Line Tool

That was easy, but we are hard-core software guys, aren’t we? We need a command line, not just browser-based functionality! Relax, there is a BigQuery command line tool, written in Python.

You can download it from here and install it by unzipping the file and running:

python setup.py install

I used the BigQuery command line tool from a Windows 7 machine; the usage is much the same on Linux, with the exception of where the credentials are stored on your local computer (~/.bigquery.v2.token and ~/.bigqueryrc on Linux, and %USERPROFILE%\.bigquery.v2.token and %USERPROFILE%\.bigqueryrc on Windows).

When you run it for the first time, it needs to be authenticated via OAuth2:

C:\BigQuery\bigquery-2.0.4>python bq.py shell

******************************************************************
** No OAuth2 credentials found, beginning authorization process **
******************************************************************

Go to the following link in your browser:

    https://accounts.google.com/o/oauth2/auth?scope=https%3A%2F%2Fwww.googleapis
.com%2Fauth%2Fbigquery&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&response
_type=code&client_id=123456789.apps.googleusercontent.com&access_type=offline

Enter verification code: *********
Authentication successful.

************************************************
** Continuing execution of BigQuery operation **
************************************************

Welcome to BigQuery! (Type help for more information.)
BigQuery> ls
   projectId     friendlyName
 -------------- --------------
  190120083879   API Project
BigQuery> exit

So the first time, you need to go to the given URL in your browser, allow access for the BigQuery Command Line tool and copy & paste the generated verification code at the “Enter verification code” prompt. The credentials are then stored on your local machine, as mentioned above, and you do not need to allow access again (unless you want to re-initialize the entire access process).

So on the second attempt to run the BigQuery shell, it goes flawlessly without authentication:

C:\BigQuery\bigquery-2.0.4>python bq.py shell
Welcome to BigQuery! (Type help for more information.)
BigQuery> ls
   projectId     friendlyName
 -------------- --------------
  190120083879   API Project
BigQuery> ls 190120083879
  datasetId
 -----------
  worldbank
BigQuery> exit
Goodbye.

To check the schema for the GDP and population tables (the population table has the same schema as GDP and was uploaded the same way – via the BigQuery Browser Tool):

C:\BigQuery\bigquery-2.0.4>python bq.py show 190120083879:worldbank.gdp
Table 190120083879:worldbank.gdp

   Last modified            Schema            Total Rows   Total Bytes
 ----------------- ------------------------- ------------ -------------
  13 May 12:10:33   |- county_code: string    195          6265
                    |- ranking: integer
                    |- country_name: string
                    |- value: integer

C:\BigQuery\bigquery-2.0.4>python bq.py show 190120083879:worldbank.population
Table 190120083879:worldbank.population

   Last modified            Schema            Total Rows   Total Bytes
 ----------------- ------------------------- ------------ -------------
  13 May 12:14:02   |- county_code: string    215          7007
                    |- ranking: integer
                    |- country_name: string
                    |- value: integer

To check the first 10 rows in the population table (you may notice that the values are ordered; that is because the values were already ordered in the WorldBank CSV file):

C:\BigQuery\bigquery-2.0.4>python bq.py head -n 10 190120083879:worldbank.popula
tion
+-------------+---------+--------------------+---------+
| county_code | ranking |    country_name    |  value  |
+-------------+---------+--------------------+---------+
| CHN         |       1 | China              | 1338300 |
| IND         |       2 | India              | 1224615 |
| USA         |       3 | United States      |  309349 |
| IDN         |       4 | Indonesia          |  239870 |
| BRA         |       5 | Brazil             |  194946 |
| PAK         |       6 | Pakistan           |  173593 |
| NGA         |       7 | Nigeria            |  158423 |
| BGD         |       8 | Bangladesh         |  148692 |
| RUS         |       9 | Russian Federation |  141750 |
| JPN         |      10 | Japan              |  127451 |
+-------------+---------+--------------------+---------+

In order to run a SELECT query against a table, first you need to initialize the project, so you have to have the .bigqueryrc file properly configured:

C:\Users\istvan>type .bigqueryrc
project_id = 190120083879
credential_file = c:\Users\istvan\.bigquery.v2.token
dataset_id = worldbank
C:\Users\istvan>

Then you can run:

C:\BigQuery\bigquery-2.0.4>python bq.py query "SELECT country_name, value FROM w
orldbank.gdp ORDER BY value DESC LIMIT 10"
Waiting on job_5745d8eb41cf489fbf6ffb7a3bc3487e ... (0s) Current status: RUNNING
Waiting on job_5745d8eb41cf489fbf6ffb7a3bc3487e ... (0s) Current status: DONE

+----------------+----------+
|  country_name  |  value   |
+----------------+----------+
| United States  | 14586736 |
| China          |  5926612 |
| Japan          |  5458837 |
| Germany        |  3280530 |
| France         |  2560002 |
| United Kingdom |  2261713 |
| Brazil         |  2087890 |
| Italy          |  2060965 |
| India          |  1727111 |
| Canada         |  1577040 |
+----------------+----------+

BigQuery API

The BigQuery browser tool and the command line tool will do in most cases. But hell, aren’t we even tougher guys – Masters of the APIs? If so, Google BigQuery offers APIs and BigQuery client libraries for us, too, available in Python, Java, .NET, PHP, Ruby, Objective-C, etc.

Here is a Python application that runs the same SELECT query that we used from the browser tool and the command line:

import httplib2
import sys
import pprint
from apiclient.discovery import build 
from apiclient.errors import HttpError
from oauth2client.client import AccessTokenRefreshError 
from oauth2client.client import OAuth2WebServerFlow 
from oauth2client.file import Storage
from oauth2client.tools import run 

FLOW = OAuth2WebServerFlow(     
    client_id='123456789.apps.googleusercontent.com',     
    client_secret='*************',     
    scope='https://www.googleapis.com/auth/bigquery',    
    user_agent='bq/2.0')  

# Run a synchronous query.
def runSyncQuery (service, projectId, datasetId, timeout=0):
  try:
    print 'timeout:%d' % timeout
    jobCollection = service.jobs()
    queryData = {'query':'SELECT country_name, value FROM worldbank.gdp ORDER BY value DESC LIMIT 10;',
                 'timeoutMs':timeout}

    queryReply = jobCollection.query(projectId=projectId,
                                     body=queryData).execute()

    jobReference=queryReply['jobReference']

    # Timeout exceeded: keep polling until the job is complete.
    while(not queryReply['jobComplete']):
      print 'Job not yet complete...'
      queryReply = jobCollection.getQueryResults(
                          projectId=jobReference['projectId'],
                          jobId=jobReference['jobId'],
                          timeoutMs=timeout).execute()

    pprint.pprint(queryReply)

  except AccessTokenRefreshError:
    print ("The credentials have been revoked or expired, please re-run"
    "the application to re-authorize")

  except HttpError as err:
    print 'Error in runSyncQuery:', pprint.pprint(err.content)

  except Exception as err:
    print 'Undefined error: %s' % err

def main():
  # If the credentials don't exist or are invalid, run the native client
  # auth flow. The Storage object will ensure that if successful the good
  # credentials will get written back to a file.
  storage = Storage('c:\Users\istvan\.bigquery.v2.token') # Choose a file name to store the credentials.
  credentials = storage.get()

  if credentials is None or credentials.invalid:
    credentials = run(FLOW, storage)

  # Create an httplib2.Http object to handle our HTTP requests and authorize it
  # with our good credentials.
  http = httplib2.Http()
  http = credentials.authorize(http)
  service = build("bigquery", "v2", http=http)

  # Now make the calls.
  print 'Make call'
  runSyncQuery(service, projectId='190120083879', datasetId='worldbank')

if __name__ == '__main__':
  main()

The output will look like this:

C:\BigQuery\PythonClient>python bq_client.py
Make call
timeout:0
Job not yet complete...
Job not yet complete...
Job not yet complete...
{u'etag': u'"6wEDxP58PwCUv91kOlRB8L7rm_A/69KAOvEhHO4pBtqit7nlzybfIPc"',
 u'jobComplete': True,
 u'jobReference': {u'jobId': u'job_9a1c0d2bcf9443b18e2204d1f4db476a',
                   u'projectId': u'190120083879'},
 u'kind': u'bigquery#getQueryResultsResponse',
 u'rows': [{u'f': [{u'v': u'United States'}, {u'v': u'14586736'}]},
           {u'f': [{u'v': u'China'}, {u'v': u'5926612'}]},
           {u'f': [{u'v': u'Japan'}, {u'v': u'5458837'}]},
           {u'f': [{u'v': u'Germany'}, {u'v': u'3280530'}]},
           {u'f': [{u'v': u'France'}, {u'v': u'2560002'}]},
           {u'f': [{u'v': u'United Kingdom'}, {u'v': u'2261713'}]},
           {u'f': [{u'v': u'Brazil'}, {u'v': u'2087890'}]},
           {u'f': [{u'v': u'Italy'}, {u'v': u'2060965'}]},
           {u'f': [{u'v': u'India'}, {u'v': u'1727111'}]},
           {u'f': [{u'v': u'Canada'}, {u'v': u'1577040'}]}],
 u'schema': {u'fields': [{u'mode': u'NULLABLE',
                          u'name': u'country_name',
                          u'type': u'STRING'},
                         {u'mode': u'NULLABLE',
                          u'name': u'value',
                          u'type': u'INTEGER'}]},
 u'totalRows': u'10'}

C:\BigQuery\PythonClient>

If you want to delve into the BigQuery API, here is the link to start.