LLMs in Snowpark
I'm in a competition right now that aims to use LLMs to make medical professionals' lives easier. While trying different models and chains, I became frustrated with how slow my computer was, so I decided to leverage the compute power of Snowflake. I loaded the data into Snowflake and wrote a Python UDF that uses a HuggingFace model for question answering.
Here's what it does on an example dataset generated by ChatGPT:
If you want to see how I did it, you can read my Medium article or continue below.
First, I created a question answering pipeline with HuggingFace's SDK.
from transformers import pipeline

# Build a question-answering pipeline around the tinyroberta SQuAD 2.0 model
model_name = "deepset/tinyroberta-squad2"
classifier = pipeline('question-answering', model=model_name, tokenizer=model_name)
Next, I used joblib to serialize the pipeline to a file, and used the Snowpark "put" method to upload that file to a Snowflake stage.
import joblib

# Dump the pipeline to a local file
joblib_name = 'squad2-qa.joblib'
stage = 'DATABASE.SCHEMA.QA_PIPLINE'
joblib.dump(classifier, joblib_name)

# Upload the file to the Snowflake stage
session.file.put(
    joblib_name,
    stage_location=f'@{stage}',
    overwrite=True,
    auto_compress=False
)
I then created a function to load the model from the joblib file and cache the loaded model.
import cachetools
import joblib
import sys

@cachetools.cached(cache={})
def read_qa_model():
    # Snowflake exposes the directory holding UDF imports via this option
    import_dir = sys._xoptions.get("snowflake_import_directory")
    if import_dir:
        # Load the model once; cachetools serves the cached object afterwards
        return joblib.load(f'{import_dir}/squad2-qa.joblib')
Finally, I created the Python UDF. It calls read_qa_model() to load the model from the staged file; any subsequent calls in the same process use the cached model.
from snowflake.snowpark.functions import udf
from snowflake.snowpark.types import StringType

@udf(
    name="DATABASE.SCHEMA.make_prediction",
    session=session,
    is_permanent=True,
    stage_location='@DATABASE.SCHEMA.STAGE',
    replace=True,
    imports=['@DATABASE.SCHEMA.STAGE/squad2-qa.joblib'],
    input_types=[StringType(), StringType()],
    return_type=StringType(),
    packages=['snowflake-snowpark-python==1.6.1', 'transformers==4.14.1', 'cachetools==4.2.2'])
def get_answer(question, context):
    # Loaded from the stage on the first call; cached afterwards
    classifier = read_qa_model()
    result = classifier(question=question, context=context)
    return result['answer']
© Thomas Smith.