"""
Store information on datasets and workflows in a local database.
"""
import os
import json
import sqlite3
import ast
import logging
# Function to map JSON schema types to SQLite data types
def map_json_type_to_sqlite(json_type):
    """Return the SQLite column type used to store a JSON schema type.

    Args:
        json_type: The value of a JSON schema ``type`` keyword. Usually a
            string such as ``"string"`` or ``"integer"``. A list of types
            (e.g. ``["integer", "null"]``) is also accepted; the first entry
            that is not ``"null"`` decides the column type.

    Returns:
        One of ``"TEXT"``, ``"INTEGER"`` or ``"REAL"``. Unknown, missing or
        non-string types fall back to ``"TEXT"``.
    """
    mapping = {
        "string": "TEXT",
        "integer": "INTEGER",
        "number": "REAL",
        # SQLite doesn't have a boolean type, typically stored as INTEGER (0 or 1)
        "boolean": "INTEGER",
        # Arrays can be stored as a serialized string (e.g., JSON or CSV)
        "array": "TEXT",
        # Nested objects could also be stored as a serialized string (e.g., JSON)
        "object": "TEXT",
    }
    if isinstance(json_type, (list, tuple)):
        # A list is unhashable and would make the dict lookup raise TypeError,
        # so reduce it to a single type name first.
        json_type = next((entry for entry in json_type if entry != "null"), None)
    if not isinstance(json_type, str):
        return "TEXT"
    return mapping.get(json_type, "TEXT")
class LLMTunerStore:
    """Small DB to hold data sets and processing info for model development.

    The store wraps a SQLite database file. The JSON schemas used to create
    tables are kept in ``self.schemas`` and persisted in a ``dbschemas.json``
    file located in the same directory as the database file.

    Args:
        filename: Path of the SQLite database file.
        datapath: Stored on the instance as ``self.datapath``; it is not used
            by the methods in this part of the class.
        keepopen: If True, the connection is kept open between operations
            instead of being closed after each one.
    """

    def __init__(self, filename="llmtuner.db", datapath="local", keepopen=False):
        self.filename = filename
        self.datapath = datapath
        self.keepopen = keepopen
        # Maps table name -> JSON schema. Initialised here because _open()
        # reads it before anything else has had a chance to set it.
        self.schemas = {}
        self.connection = sqlite3.connect(self.filename)
        self.cursor = self.connection.cursor()

    def _schema_path(self):
        """Return the path of ``dbschemas.json`` next to the database file."""
        # os.path.dirname yields "" for a bare file name, so the schema file
        # then lands in the current directory instead of a mangled path.
        return os.path.join(os.path.dirname(self.filename), "dbschemas.json")

    def _open(self):
        """Open connection to the DB and load the stored schemas if needed.

        An already open connection is reused; a new one is only created when
        there is none or the current one can no longer be used.
        """
        try:
            self.cursor = self.connection.cursor()
        except (AttributeError, sqlite3.ProgrammingError):
            # No connection yet, or it has been closed: connect again.
            self.connection = sqlite3.connect(self.filename)
            self.cursor = self.connection.cursor()
        if not self.schemas:
            schema_path = self._schema_path()
            if os.path.isfile(schema_path):
                with open(schema_path, "r") as file:
                    self.schemas = json.loads(file.read())
            else:
                logging.warning("Could not load database schemas from %s", schema_path)

    def _close(self, confirm=False):
        """Close the connection (unless kept open) and persist the schemas.

        Set confirm = True to force closing of connection.
        """
        if not self.keepopen or confirm:
            self.connection.close()
        with open(self._schema_path(), "w") as file:
            file.write(json.dumps(self.schemas))

    def _keep_open(self, keepopen=True):
        """Set to false if connection should be closed after each operation"""
        self.keepopen = keepopen

    @staticmethod
    def _query_create_table_from_schema(schema, table_name=""):
        """Build the CREATE TABLE statement for a JSON schema.

        Args:
            schema: JSON schema dict with a 'properties' entry, optionally
                'required', and 'title' (used when no table name is given).
            table_name: Name of the table; defaults to ``schema['title']``.

        Returns:
            Tuple ``(table_name, sqlcommand)``.
        """
        if not table_name:
            table_name = schema['title']
        required_fields = schema.get('required', [])
        columns = []
        for field, attributes in schema['properties'].items():
            # A property without an explicit type is stored as TEXT.
            column_definition = f"{field} {map_json_type_to_sqlite(attributes.get('type'))}"
            if field in required_fields:
                column_definition += " NOT NULL"
            columns.append(column_definition)
        sqlcommand = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)});"
        return table_name, sqlcommand

    def list_tables(self):
        """Return list of available tables as 1-tuples ``(name,)``."""
        self._open()
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = self.cursor.fetchall()
        self._close()
        return tables

    def create_table(self, schema, table_name=""):
        """Creates a DB table from a json schema and remembers the schema."""
        self._open()
        table_name, query = self._query_create_table_from_schema(schema, table_name)
        self.cursor.execute(query)
        self.connection.commit()
        self.schemas[table_name] = schema
        print("Created DB table", table_name)
        self._close()

    def delete_table(self, table_name):
        """Delete a table with 'table_name' from the database"""
        self._open()
        self.cursor.execute(f"DROP TABLE IF EXISTS {table_name};")
        self.connection.commit()
        # Forget the schema as well so that no stale entry is written back.
        self.schemas.pop(table_name, None)
        self._close()

    def add_data_to_table(self, table_name, data_dict, uniquekeys=None, update=False, mergekeys=None):
        """Add data from dictionary to a table in sqlite connection.

        Args:
            table_name: Name of the target table; its schema must be known.
            data_dict: Column name -> value of the entry to store.
            uniquekeys: Column name or list of column names. If given, the
                entry is only added if no other entry with the same values
                for these columns is present in the table.
            update: If True, an existing entry matching 'uniquekeys' is
                updated instead of being skipped.
            mergekeys: Accepted for compatibility; not used here.
        """
        self._open()
        # NOTE(review): format_data_to_sql is not defined in the part of the
        # class shown here; it is expected to be provided elsewhere - confirm.
        row = self.format_data_to_sql(self.schemas[table_name], data_dict)
        if isinstance(uniquekeys, str):
            uniquekeys = [uniquekeys]
        uniquekeys = list(uniquekeys or [])

        is_new = True
        if uniquekeys:
            keyvals = {key: row[key] for key in uniquekeys}
            founddata = self.get_data_from_table(table_name, key_values=keyvals, asdicts=False)
            logging.debug("Looking for prior entries with "+str(keyvals)+". Found "+str(len(founddata))+ " entries.")
            if founddata:
                is_new = False
                if update:
                    # The lookup above may have closed the connection.
                    self._open()
                    set_clause = ", ".join(f"{col} = ?" for col in row)
                    where_clause = " AND ".join(f"{key} = ?" for key in uniquekeys)
                    sql_query = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
                    values = list(row.values()) + [keyvals[key] for key in uniquekeys]
                    logging.debug("query: "+sql_query)
                    self.cursor.execute(sql_query, values)
                    # Commit before _close(), otherwise the update is lost.
                    self.connection.commit()
                else:
                    logging.info(f"Entries with {keyvals} already exist. Skipping")

        if is_new:
            self._open()
            fields = ', '.join(row)
            placeholders = ', '.join('?' for _ in row)
            insert_sql = f"INSERT INTO {table_name} ({fields}) VALUES ({placeholders})"
            self.cursor.execute(insert_sql, tuple(row.values()))
            self.connection.commit()
            logging.debug("Added: "+insert_sql)
        self._close()

    def delete_data_from_table(self, table_name, key_values=None):
        """Delete data from table, providing a table name and keys and values of entries to delete as dictionary.

        All given column/value pairs must match for a row to be deleted.
        Without any pair, every row of the table is deleted.
        """
        self._open()
        key_values = key_values or {}
        query = "DELETE FROM " + table_name
        if key_values:
            query += " WHERE " + " AND ".join(f"{key} = ?" for key in key_values)
        # Values are bound as parameters so quotes or non-string values in
        # them cannot break or alter the statement.
        self.cursor.execute(query, list(key_values.values()))
        self.connection.commit()
        self._close()

    def get_data_from_table(self, table_name, key_values=None, asdicts=False):
        """Get rows from a table, optionally filtered by column values.

        Args:
            table_name: Name of the table to read.
            key_values: Column name -> wanted value; all criteria must match.
                A list (or tuple) selects rows whose column equals any of its
                items. A string 'select:XXX' selects rows whose column text
                contains "XXX" in double quotes (for columns holding
                serialized arrays). Any other value is matched by equality;
                None values are ignored. The passed objects are not modified.
            asdicts: If True, rows are returned as dicts keyed by column name.

        Returns:
            List of rows, as tuples or (with asdicts) as dictionaries.
        """
        self._open()
        conditions = []
        values = []
        for key, wanted in (key_values or {}).items():
            if isinstance(wanted, (list, tuple)):
                placeholders = ", ".join("?" for _ in wanted)
                conditions.append(f"{key} IN ({placeholders})")
                values.extend(wanted)
            elif isinstance(wanted, str) and wanted.startswith("select:"):
                # Slice the prefix off; str.lstrip would strip a character
                # set and eat leading letters of the search term.
                conditions.append(f"{key} LIKE ?")
                values.append('%"' + wanted[len("select:"):] + '"%')
            elif wanted is not None:
                conditions.append(f"{key} = ?")
                values.append(wanted)

        sql_query = "SELECT * FROM " + table_name
        if conditions:
            sql_query += " WHERE " + " AND ".join(conditions)
        logging.debug("Executing query: '"+sql_query+"' with values "+str(values))
        self.cursor.execute(sql_query, values)
        rows = self.cursor.fetchall()
        if asdicts:
            column_names = [description[0] for description in self.cursor.description]
            rows = [dict(zip(column_names, row)) for row in rows]
        self._close()
        return rows