Skip to content

Commit

Permalink
fix[output_format]: accept dataframe dict as output and secure sql qu… (
Browse files Browse the repository at this point in the history
#1432)

* fix[output_format]: accept dataframe dict as output and secure sql query execution

* fix: ruff errors
  • Loading branch information
ArslanSaleem authored Nov 18, 2024
1 parent 719043c commit 437f949
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def execute_direct_sql_query(self, sql_query):
if not self._is_sql_query_safe(sql_query):
raise MaliciousQueryError("Malicious query is generated in code")

return pd.read_sql(sql_query, self._connection)
return pd.read_sql(text(sql_query), self._connection)

@property
def cs_table_name(self):
Expand Down
4 changes: 2 additions & 2 deletions pandasai/helpers/output_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def validate_value(self, expected_type: str) -> bool:
elif expected_type == "string":
return isinstance(self, str)
elif expected_type == "dataframe":
return isinstance(self, (pd.DataFrame, pd.Series))
return isinstance(self, (pd.DataFrame, pd.Series, dict))
elif expected_type == "plot":
if not isinstance(self, (str, dict)):
return False
Expand All @@ -82,7 +82,7 @@ def validate_result(result: dict) -> bool:
elif result["type"] == "string":
return isinstance(result["value"], str)
elif result["type"] == "dataframe":
return isinstance(result["value"], (pd.DataFrame, pd.Series))
return isinstance(result["value"], (pd.DataFrame, pd.Series, dict))
elif result["type"] == "plot":
if "plotly" in repr(type(result["value"])):
return True
Expand Down
12 changes: 12 additions & 0 deletions pandasai/responses/response_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any

import pandas as pd
from PIL import Image

from pandasai.exceptions import MethodNotImplementedError
Expand Down Expand Up @@ -51,9 +52,20 @@ def parse(self, result: dict) -> Any:

if result["type"] == "plot":
return self.format_plot(result)
elif result["type"] == "dataframe":
return self.format_dataframe(result)
else:
return result["value"]

def format_dataframe(self, result: dict) -> Any:
if isinstance(result["value"], dict):
print("Df conversiont")
df = pd.Dataframe(result["value"])
print("Df conversiont Done")
result["value"] = df

return result["value"]

def format_plot(self, result: dict) -> Any:
"""
Display matplotlib plot against a user query.
Expand Down

0 comments on commit 437f949

Please sign in to comment.