LLMs in Snowpark

I'm in a competition right now that aims to use LLMs to make medical professionals' lives easier. While trying different models and chains, I became frustrated with how slow my computer was, so I decided to leverage the compute power of Snowflake. I loaded the data into Snowflake and wrote a Python UDF that uses a HuggingFace model for question answering.

Here's what it does on an example dataset generated by ChatGPT:

If you want to see how I did it, you can read my Medium article, or continue below.

First, I created a question answering pipeline with HuggingFace's SDK.

from transformers import pipeline

model_name = "deepset/tinyroberta-squad2"
classifier = pipeline('question-answering', model=model_name, tokenizer=model_name)

Next, I used joblib to serialize the pipeline to a file that I can store in Snowflake, then used the Snowpark "put" method to upload it to a Snowflake stage.

import joblib

# Dump the pipeline to a local joblib file
joblib_name = 'squad2-qa.joblib'
stage = 'DATABASE.SCHEMA.QA_PIPLINE'
joblib.dump(classifier, joblib_name)

# Upload the file to the Snowflake stage
session.file.put(
    joblib_name,
    stage_location=f'@{stage}',
    overwrite=True,
    auto_compress=False
)
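To double-check that the file actually landed in the stage, Snowflake's LIST command can be run through the session. A minimal sketch, assuming the same stage name as above:

```python
# Build a LIST query for the stage that now holds the joblib file
stage = 'DATABASE.SCHEMA.QA_PIPLINE'
list_query = f"LIST @{stage}"
# session.sql(list_query).show()  # requires an active Snowpark session
```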

I then created a function to load the model from the joblib file and cache the loaded model.

import cachetools
import joblib
import sys

@cachetools.cached(cache={})
def read_qa_model():
    # Snowflake extracts UDF imports into this directory at runtime
    import_dir = sys._xoptions.get("snowflake_import_directory")
    if import_dir:
        # Load the model once; cachetools reuses it on later calls
        return joblib.load(f'{import_dir}/squad2-qa.joblib')
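The decorator memoizes read_qa_model(), so the joblib file is deserialized only once per Python process; later calls reuse the in-memory model. The same effect can be sketched with the standard library's functools.lru_cache (the load_count counter and string stand-in are illustrative, not part of the real pipeline):

```python
import functools

load_count = 0  # tracks how many times the "expensive" load actually runs

@functools.lru_cache(maxsize=None)
def read_model_sketch():
    # Stand-in for joblib.load(...) inside read_qa_model()
    global load_count
    load_count += 1
    return "loaded-model"

first = read_model_sketch()
second = read_model_sketch()
# load_count stays at 1: the second call returns the cached object
```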

After that, I created the Python UDF. It calls read_qa_model() to load the model from the imported joblib file; any subsequent calls reuse the cached model.


from snowflake.snowpark.functions import udf
from snowflake.snowpark.types import StringType

@udf(
    name="DATABASE.SCHEMA.make_prediction",
    session=session,
    is_permanent=True,
    stage_location='@DATABASE.SCHEMA.STAGE',
    replace=True,
    imports=['@DATABASE.SCHEMA.STAGE/squad2-qa.joblib'],
    input_types=[StringType(), StringType()],
    return_type=StringType(),
    packages=['snowflake-snowpark-python==1.6.1', 'transformers==4.14.1', 'cachetools==4.2.2'])
def get_answer(question, context):
    # Loaded from the stage on the first call, cached afterwards
    classifier = read_qa_model()
    result = classifier(question=question, context=context)
    return result['answer']
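With the UDF deployed, it can be called from SQL or a Snowpark DataFrame like any other function. A minimal sketch, assuming a hypothetical qa_table with QUESTION and CONTEXT columns:

```python
# Hypothetical query against a table with QUESTION and CONTEXT columns
query = (
    "SELECT QUESTION, "
    "DATABASE.SCHEMA.make_prediction(QUESTION, CONTEXT) AS ANSWER "
    "FROM qa_table"
)
# session.sql(query).show()  # requires an active Snowpark session
```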
© Thomas Smith.