Coverage for api/roles/services.py: 93.83%
162 statements
« prev ^ index » next coverage.py v7.9.2, created at 2026-01-25 13:05 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2026-01-25 13:05 +0000
1from typing import Dict, List
2from models.roles import Roles
3from models.role_mapper import RoleMapper
4from core.rbac import check_user_has_super_role
5from sqlalchemy.ext.asyncio import AsyncSession
6from sqlalchemy import select, func, delete, and_
7from models.role_attributes import RoleAttributes
8from models.role_attributes_mapper import RoleAttributesMapper
9from utils.custom_exception import ServerException, ConflictException, NotFoundException
10from .schema import (
11 RoleResponse, RoleCreate, RoleUpdate, RolesListResponse,
12 RoleAttributeMappingBatchResponse, AttributeMappingResult,
13 RoleAttributesGroupedResponse, RoleAttributesGroup, RoleAttributeDetail,
14 PermissionCheckResponse
15)
async def get_all_roles(db: AsyncSession) -> RolesListResponse:
    """Return every role, sorted alphabetically by name.

    Args:
        db: Async database session used for the query.

    Returns:
        A ``RolesListResponse`` wrapping one ``RoleResponse`` per role.

    Raises:
        ServerException: If the query or response construction fails.
    """
    try:
        ordered = await db.execute(select(Roles).order_by(Roles.name.asc()))
        return RolesListResponse(
            roles=[
                RoleResponse(id=r.id, name=r.name, description=r.description)
                for r in ordered.scalars().all()
            ]
        )
    except Exception as e:
        raise ServerException(f"Failed to retrieve roles: {str(e)}")
async def create_role(db: AsyncSession, role_data: RoleCreate) -> RoleResponse:
    """Create a new role with a unique name.

    Args:
        db: Async database session; committed on success, rolled back on
            unexpected failure.
        role_data: Name and description for the new role.

    Returns:
        The persisted role as a ``RoleResponse``.

    Raises:
        ConflictException: If a role with the same name already exists.
        ServerException: For any other failure (query, commit, refresh).
    """
    try:
        # NOTE(review): check-then-insert is racy; a concurrent insert of the
        # same name is only caught if the DB enforces uniqueness, and then it
        # surfaces as ServerException rather than ConflictException.
        existing_role = await db.execute(
            select(Roles).where(Roles.name == role_data.name)
        )
        if existing_role.scalar_one_or_none():
            raise ConflictException("Role name already exists")

        role = Roles(
            name=role_data.name,
            description=role_data.description
        )
        db.add(role)
        await db.commit()
        # Reload the row so the attributes read below reflect committed state.
        await db.refresh(role)

        return RoleResponse(
            id=role.id,
            name=role.name,
            description=role.description
        )

    except ConflictException:
        raise
    except Exception as e:
        # A failed flush/commit leaves the session unusable until it is rolled
        # back; discard the pending insert so the caller's session stays usable.
        await db.rollback()
        raise ServerException(f"Failed to create role: {str(e)}") from e
async def update_role(db: AsyncSession, role_id: str, role_data: RoleUpdate) -> RoleResponse:
    """Apply a partial update to an existing role.

    Only the fields explicitly set on ``role_data``
    (``model_dump(exclude_unset=True)``) are written to the role. A rename is
    rejected when another role already uses the new name.

    Args:
        db: Async database session; committed on success, rolled back on
            unexpected failure.
        role_id: Identifier of the role to update.
        role_data: Fields to change.

    Returns:
        The updated role as a ``RoleResponse``.

    Raises:
        NotFoundException: If no role has ``role_id``.
        ConflictException: If the new name is taken by a different role.
        ServerException: For any other failure (query, commit, refresh).
    """
    try:
        role_result = await db.execute(
            select(Roles).where(Roles.id == role_id)
        )
        role = role_result.scalar_one_or_none()
        if not role:
            raise NotFoundException("Role not found")

        if role_data.name and role_data.name != role.name:
            existing_role = await db.execute(
                select(Roles).where(Roles.name == role_data.name, Roles.id != role_id)
            )
            if existing_role.scalar_one_or_none():
                raise ConflictException("Role name already exists")

        # NOTE(review): every explicitly-set field is copied onto the model,
        # including an explicit falsy name (None / ""), which skips the
        # conflict check above — confirm RoleUpdate validation rejects those.
        update_data = role_data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(role, field, value)

        await db.commit()
        await db.refresh(role)

        return RoleResponse(
            id=role.id,
            name=role.name,
            description=role.description
        )

    except (ConflictException, NotFoundException):
        raise
    except Exception as e:
        # Discard the in-session modifications so the session stays usable
        # after a failed commit.
        await db.rollback()
        raise ServerException(f"Failed to update role: {str(e)}") from e
async def delete_role(db: AsyncSession, role_id: str) -> bool:
    """Delete a role together with its attribute mappings.

    The role must exist and must not be referenced by any ``RoleMapper`` row.
    Its ``RoleAttributesMapper`` rows are deleted first, then the role itself,
    and both deletes are committed in a single commit.

    Args:
        db: Async database session; committed on success, rolled back on
            unexpected failure.
        role_id: Identifier of the role to delete.

    Returns:
        ``True`` on success.

    Raises:
        NotFoundException: If no role has ``role_id``.
        ConflictException: If the role is still assigned to users.
        ServerException: For any other failure.
    """
    try:
        role_result = await db.execute(
            select(Roles).where(Roles.id == role_id)
        )
        role = role_result.scalar_one_or_none()
        if not role:
            raise NotFoundException("Role not found")

        user_count = await db.execute(
            select(func.count(RoleMapper.user_id)).where(RoleMapper.role_id == role_id)
        )
        if user_count.scalar() > 0:
            raise ConflictException("Cannot delete role that is assigned to users")

        # Remove dependent mappings before the role row they reference.
        await db.execute(
            delete(RoleAttributesMapper).where(RoleAttributesMapper.role_id == role_id)
        )

        await db.execute(
            delete(Roles).where(Roles.id == role_id)
        )
        await db.commit()

        return True

    except (ConflictException, NotFoundException):
        raise
    except Exception as e:
        # If the second DELETE or the commit fails, the first DELETE is still
        # pending in this transaction; roll it back so nothing half-applied
        # can be committed later through the same session.
        await db.rollback()
        raise ServerException(f"Failed to delete role: {str(e)}") from e
async def get_role_attribute_mapping(db: AsyncSession, role_id: str) -> RoleAttributesGroupedResponse:
    """Return every attribute with this role's value, grouped by group and category.

    All ``RoleAttributes`` rows are listed via a LEFT OUTER JOIN against this
    role's ``RoleAttributesMapper`` rows, so attributes the role has no mapping
    for are reported with ``value=False``. A missing group is bucketed as
    ``"default"`` and a missing category as ``"uncategorized"``. Groups,
    categories and attributes appear in first-seen order of ``RoleAttributes.id``.

    Args:
        db: Async database session used for the queries.
        role_id: Identifier of the role to report on.

    Returns:
        A ``RoleAttributesGroupedResponse`` with one ``RoleAttributesGroup`` per
        group, each mapping category name to its ``RoleAttributeDetail`` list.

    Raises:
        NotFoundException: If no role has ``role_id``.
        ServerException: For any other failure.
    """
    try:
        found = await db.execute(select(Roles).where(Roles.id == role_id))
        if not found.scalar_one_or_none():
            raise NotFoundException("Role not found")

        # The role filter lives in the ON clause (not WHERE) so attributes
        # without a mapping for this role still come back with a NULL value.
        mapping_for_role = and_(
            RoleAttributes.id == RoleAttributesMapper.attributes_id,
            RoleAttributesMapper.role_id == role_id,
        )
        stmt = (
            select(
                RoleAttributes.name,
                RoleAttributes.group,
                RoleAttributes.category,
                RoleAttributesMapper.value,
            )
            .select_from(RoleAttributes)
            .outerjoin(RoleAttributesMapper, mapping_for_role)
            .order_by(RoleAttributes.id)
        )
        fetched = await db.execute(stmt)
        records = fetched.all()

        tree: Dict[str, Dict[str, List[RoleAttributeDetail]]] = {}
        for record in records:
            group_key = record.group or "default"
            category_key = record.category or "uncategorized"
            bucket = tree.setdefault(group_key, {}).setdefault(category_key, [])
            bucket.append(
                RoleAttributeDetail(
                    name=record.name,
                    value=False if record.value is None else record.value,
                )
            )

        return RoleAttributesGroupedResponse(
            groups=[
                RoleAttributesGroup(group=group_key, categories=by_category)
                for group_key, by_category in tree.items()
            ]
        )

    except NotFoundException:
        raise
    except Exception as e:
        raise ServerException(f"Failed to get role attributes: {str(e)}")
async def update_role_attribute_mapping(db: AsyncSession, role_id: str, attributes_data: Dict[str, bool]) -> RoleAttributeMappingBatchResponse:
    """Batch upsert a role's attribute values and report a result per attribute.

    Each key of ``attributes_data`` is an attribute *name*. Names with no
    matching ``RoleAttributes`` row are reported as failed. Valid names update
    the role's existing ``RoleAttributesMapper`` row(s) or add a new one. All
    changes are committed together at the end.

    Args:
        db: Async database session; committed on success, rolled back on
            unexpected failure.
        role_id: Identifier of the role whose mappings are updated.
        attributes_data: Mapping of attribute name to the value to store.

    Returns:
        A ``RoleAttributeMappingBatchResponse`` holding one result per input
        name (invalid names first, then valid ones, each group in input order)
        plus success/failure counts.

    Raises:
        NotFoundException: If no role has ``role_id``.
        ServerException: For any other failure, including the final commit.
    """
    try:
        role_result = await db.execute(
            select(Roles).where(Roles.id == role_id)
        )
        role = role_result.scalar_one_or_none()
        if not role:
            raise NotFoundException("Role not found")

        results: List[AttributeMappingResult] = []
        success_count = 0
        failed_count = 0

        # Resolve attribute names to ids with a single query.
        attribute_names = list(attributes_data.keys())
        name_to_id_map = {}
        if attribute_names:
            existing_attributes = await db.execute(
                select(RoleAttributes.id, RoleAttributes.name).where(RoleAttributes.name.in_(attribute_names))
            )
            name_to_id_map = {row.name: row.id for row in existing_attributes}

        # Report unknown names in input order; iterating a set here would make
        # the order of these failure entries vary between processes.
        for attribute_name in attribute_names:
            if attribute_name not in name_to_id_map:
                results.append(AttributeMappingResult(
                    attribute_id=attribute_name,  # Keep name for error reporting
                    status="failed",
                    message="Invalid attribute name"
                ))
                failed_count += 1

        # Prefetch this role's existing mappings for all valid attributes in one
        # query instead of one SELECT per attribute. Besides saving round trips,
        # this keeps autoflush from running inside the loop below, where a flush
        # error for one newly added mapping would be attributed to a different
        # attribute and leave the session unusable for the rest of the batch.
        existing_mappings = {}
        if name_to_id_map:
            mappings_result = await db.execute(
                select(RoleAttributesMapper).where(
                    and_(
                        RoleAttributesMapper.role_id == role_id,
                        RoleAttributesMapper.attributes_id.in_(list(name_to_id_map.values()))
                    )
                )
            )
            for mapping in mappings_result.scalars().all():
                # Grouped as lists so duplicate rows, if any, are all updated.
                existing_mappings.setdefault(mapping.attributes_id, []).append(mapping)

        for attribute_name, value in attributes_data.items():
            attribute_id = name_to_id_map.get(attribute_name)
            if attribute_id is None:
                continue  # already reported as invalid above

            try:
                mappings = existing_mappings.get(attribute_id)
                if mappings:
                    for mapping in mappings:
                        mapping.value = value
                else:
                    db.add(RoleAttributesMapper(
                        role_id=role_id,
                        attributes_id=attribute_id,
                        value=value
                    ))

                results.append(AttributeMappingResult(
                    attribute_id=attribute_name,
                    status="success",
                    message="Updated successfully"
                ))
                success_count += 1

            except Exception as e:
                results.append(AttributeMappingResult(
                    attribute_id=attribute_name,
                    status="failed",
                    message=f"Failed to process: {str(e)}"
                ))
                failed_count += 1

        await db.commit()

        return RoleAttributeMappingBatchResponse(
            results=results,
            total_attributes=len(attributes_data),
            success_count=success_count,
            failed_count=failed_count
        )

    except NotFoundException:
        raise
    except Exception as e:
        # Drop every staged change from this batch so the session stays usable.
        await db.rollback()
        raise ServerException(f"Failed to update role attributes mapping: {str(e)}") from e
async def check_user_permissions(
    db: AsyncSession,
    user_id: str,
    required_attributes: List[str] = None
) -> PermissionCheckResponse:
    """Report which permission attributes a user holds.

    Args:
        db: Async database session used for the queries.
        user_id: Identifier of the user to check.
        required_attributes: Attribute names to check. When ``None`` or empty,
            every attribute name in ``RoleAttributes`` is checked.

    Returns:
        A ``PermissionCheckResponse`` mapping each checked attribute name to a
        bool. If ``check_user_has_super_role`` is true for the user, every
        checked attribute is ``True``. Otherwise an attribute is ``True`` when
        any role mapped to the user (via ``RoleMapper``) has a
        ``RoleAttributesMapper`` row for it with value ``True``.

    Raises:
        ServerException: If any query fails.
    """
    try:
        if required_attributes:
            checked_attributes = list(required_attributes)
        else:
            # No explicit list was given: report on every defined attribute.
            # Queried only in this branch because it is unused otherwise.
            all_attributes_result = await db.execute(select(RoleAttributes.name))
            checked_attributes = [row[0] for row in all_attributes_result.fetchall()]

        if await check_user_has_super_role(user_id, db):
            return PermissionCheckResponse(
                permissions={attr: True for attr in checked_attributes}
            )

        # Resolve granted attribute names through RoleMapper in one query.
        # Joining (rather than fetching a single role id with
        # scalar_one_or_none, which raises when several rows match) also
        # handles users mapped to several roles: they get the union of those
        # roles' granted attributes.
        granted_query = (
            select(RoleAttributes.name)
            .join(
                RoleAttributesMapper,
                RoleAttributes.id == RoleAttributesMapper.attributes_id
            )
            .join(
                RoleMapper,
                RoleMapper.role_id == RoleAttributesMapper.role_id
            )
            .where(
                and_(
                    RoleMapper.user_id == user_id,
                    RoleAttributesMapper.value == True  # noqa: E712 - SQL expression
                )
            )
        )
        granted_result = await db.execute(granted_query)
        granted = {row[0] for row in granted_result.fetchall()}

        return PermissionCheckResponse(
            permissions={attr: attr in granted for attr in checked_attributes}
        )

    except Exception as e:
        raise ServerException(f"Failed to check user permissions: {str(e)}") from e